{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 手写数字识别\n",
    "对于输入图像，预测图中的数字是0到9中的哪个的问题（10类别分类问题）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MNIST数据集\n",
    "**load_mnist** 函数以“**(训练图像，训练标签),(测试图像，测试标签)**”的形式返回读入的MNIST数据。\n",
    "load_mnist(normalize=True,flatten=True,one_hot_label=True) 可以设置3个参数。\n",
    "* normalize设置是否将输入图像正规化为0.0\\~1.0的值。<br>如果将该参数设置为False，则输入图像的像素会保持原来的0\\~255。<br>如果设置为True，函数内部会进行转换，将图像的各个像素值除以255，使得数据的值在0.0\\~1.0的范围内（正规化）。\n",
    "* flatten设置是否展开输入图像（变成一维数组）。<br>如果将该参数设置为False，则输入图像为1x28x28的三维数组；<br>若设置为True，则输入图像会保存为由784个元素构成的一维数组。\n",
    "* one_hot_label设置是否将标签保存为one-hot表示，就像[0,0,1,0,0,0,0,0,0,0]。<br>当one_hot_label为False时，只是像7，2这样简单保存正确解标签；<br>当one_hot_label为True时，标签则保存为one-hot表示。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 显示MNITS图像(训练图像的第一张)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n",
      "(784,)\n",
      "(28, 28)\n"
     ]
    }
   ],
   "source": [
    "import sys, os\n",
    "sys.path.append(os.pardir) # 把父目录加入到sys.path(Python的搜索模块的路径集)中，从而可以导入父目录下的任何目录\n",
    "import numpy as np\n",
    "from dataset.mnist import load_mnist\n",
    "from PIL import Image # 图像的显示使用PIL(Python Image Library)模块\n",
    "\n",
    "def img_show(img):\n",
    "    pil_img = Image.fromarray(np.uint8(img))# 把保存为Numpy数组的图像数据转换为PIL用的数据对象\n",
    "    pil_img.show()\n",
    "    \n",
    "(x_train, t_train), (x_test,t_test) = load_mnist(flatten=True, normalize=False) # flatten为True，表示展开输入图像（变成一位数组），如果为False，则输入图像为1*28*28的三维数组\n",
    "\n",
    "img = x_train[0]    # 第一张训练图像\n",
    "label = t_train[0]  # 第一个训练标签\n",
    "\n",
    "print(label) # 5\n",
    "\n",
    "print(img.shape)          # (784,)\n",
    "img = img.reshape(28, 28) # 把图像的形状变成原来的尺寸\n",
    "print(img.shape)          # (28, 28)\n",
    "\n",
    "img_show(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 神经网络的推理处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# neuralnet_mnist.py\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1+np.exp(-x))\n",
    "\n",
    "def softmax(a): # 输出层的激活函数\n",
    "    c = np.max(a)\n",
    "    exp_a = np.exp(a-c) # 溢出对策\n",
    "    sum_exp_a = np.sum(exp_a)\n",
    "    y = exp_a / sum_exp_a\n",
    "    \n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def get_data(): # 从数据集中读取数据\n",
    "    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False) # 读取数据集\n",
    "    return x_test, t_test\n",
    "\n",
    "def init_network(): # 读入samle_weight.pkl中的权重参数\n",
    "    with open(\"sample_weight.pkl\", 'rb') as f:\n",
    "        network = pickle.load(f)\n",
    "    return network\n",
    "\n",
    "def predict(network, x): # \n",
    "    W1, W2, W3 = network['W1'], network['W2'], network['W3']\n",
    "    b1, b2, b3 = network['b1'], network['b2'], network['b3']\n",
    "    a1 = np.dot(x, W1) + b1\n",
    "    z1 = sigmoid(a1)\n",
    "    a2 = np.dot(z1, W2) + b2\n",
    "    z2 = sigmoid(a2)\n",
    "    a3 = np.dot(z2, W3) + b3\n",
    "    y = softmax(a3)\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:0.9352\n"
     ]
    }
   ],
   "source": [
    "# 实现神经网络的推理处理\n",
    "x, t = get_data()\n",
    "network = init_network()\n",
    "\n",
    "accuracy_cnt = 0 # 识别精度，即能在多大程度上正确分类\n",
    "for i in range(len(x)):\n",
    "    y = predict(network, x[i])\n",
    "    p = np.argmax(y) # 获取概率最高的元素的索引\n",
    "    if p == t[i]:\n",
    "        accuracy_cnt += 1\n",
    "\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt) / len(x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 批处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 784)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x, _ = get_data()\n",
    "network = init_network()\n",
    "W1, W2, W3 = network['W1'], network['W2'], network['W3']\n",
    "\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(784,)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(784, 50)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 100)"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 10)"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W3.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过np.argmax()获取值最大的元素的索引。\n",
    "* 如果给定了参数 axis=0,代表沿着第0维，即列方向，找到值最大的元素的索引。\n",
    "* 如果给定了参数 axis=1,代表沿着第1维，即行方向，找到值最大的元素的索引。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 1 0]\n"
     ]
    }
   ],
   "source": [
    "x = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])\n",
    "y = np.argmax(x, axis=1) # 沿着第1维，即行方向，找到值最大的元素的索引\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy:0.9352\n"
     ]
    }
   ],
   "source": [
    "# 基于批处理实现代码\n",
    "x, t = get_data()\n",
    "network = init_network()\n",
    "\n",
    "batch_size = 100 # 批数量\n",
    "accuracy_cnt = 0\n",
    "\n",
    "for i in range(0, len(x), batch_size):\n",
    "    x_batch = x[i:i+batch_size]\n",
    "    y_batch = predict(network, x_batch)\n",
    "    p = np.argmax(y_batch, axis=1)\n",
    "    accuracy_cnt += np.sum(p == t[i:i+batch_size])\n",
    "\n",
    "print(\"Accuracy:\" + str(float(accuracy_cnt / len(x))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 小节\n",
    "\n",
    "\n",
    "* 神经网络中的激活函数使用平滑变化的 sigmoid 函数或者 ReLU 函数。\n",
    "* 通过巧妙地使用 NumPy 多维数组，可以高效地实现神经网络。\n",
    "* 机器学习的问题大体上可以分为回归问题和分类问题。\n",
    "* 关于输出层的激活函数，回归问题中一般用恒等函数，分类问题中一般用 softmax 函数。\n",
    "* 分类问题中，输出层的神经元的数量设置为要分类的类别数。\n",
    "* 输入数据的集合称为批。通过以批为单位进行推理处理，能够实现高效的运算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
